
package io.gitee.h25094152.crypto.rsa;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.InputSource;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Converts RSA keys generated by .NET, stored in the {@code <RSAKeyValue>} XML format, into
 * Base64-encoded DER keys that Java's {@link KeyFactory} can load.
 *
 * <p>Private keys are emitted as PKCS#8 ({@code PrivateKey.getEncoded()}) and public keys as
 * X.509 SubjectPublicKeyInfo ({@code PublicKey.getEncoded()}). Despite the method names, the
 * result is only the Base64 body of a PEM file: it has no {@code -----BEGIN/END-----} armor
 * lines and no line wrapping.
 *
 * @author wh
 */
public class RSAXmlToPem {

    private static final Logger LOG = Logger.getLogger(RSAXmlToPem.class.getName());

    /** Returned by {@code getKeyType} when the root element is not {@code <RSAKeyValue>}. */
    private static final int INVALID_KEY = 0;
    private static final int PRIVATE_KEY = 1;
    private static final int PUBLIC_KEY = 2;
    /** Root element name of an XML RSA key. */
    private static final String ROOT_NODE = "RSAKeyValue";
    /** Elements that must all be present for a private key (full CRT form). */
    private static final String[] PRIVATE_KEY_XML_NODES = {"Modulus", "Exponent", "P", "Q", "DP", "DQ", "InverseQ", "D"};
    /** Elements that must all be present for a public key. */
    private static final String[] PUBLIC_KEY_XML_NODES = {"Modulus", "Exponent"};

    /**
     * Converts an XML-format RSA key string to a Base64-encoded DER key.
     *
     * @param xmlStr XML key text whose root element is {@code <RSAKeyValue>}
     * @return the Base64 key (PKCS#8 for private keys, X.509 for public keys); {@code ""} if the
     *         input is {@code null} or cannot be parsed as XML; {@code null} if the XML is not a
     *         complete RSA key or the key cannot be built
     */
    public static String transXmlStrToPem(String xmlStr) {
        return convertDocument(parseXMLStr(xmlStr));
    }

    /**
     * Reads an XML-format RSA key file and converts it to a Base64-encoded DER key.
     *
     * @param path path of the XML key file
     * @return same contract as {@link #transXmlStrToPem(String)}
     */
    public static String transXmlTotPem(String path) {
        return convertDocument(parseXMLFile(path));
    }

    /**
     * Shared conversion step for both public entry points.
     *
     * @param xmldoc parsed key document, or {@code null} if parsing failed
     * @return the Base64 key; {@code ""} if {@code xmldoc} is null or an unexpected runtime
     *         error occurs; {@code null} if the document is not a complete RSA key
     */
    private static String convertDocument(Document xmldoc) {
        if (xmldoc == null) {
            return "";
        }
        try {
            int keyType = getKeyType(xmldoc);
            if (keyType == INVALID_KEY || !checkXMLRSAKey(keyType, xmldoc)) {
                return null;
            }
            return keyType == PRIVATE_KEY
                    ? convertXMLRSAPrivateKeyToPEM(xmldoc)
                    : convertXMLRSAPublicKeyToPEM(xmldoc);
        } catch (RuntimeException e) {
            LOG.log(Level.WARNING, "Failed to convert XML RSA key", e);
            return "";
        }
    }

    /**
     * Determines the key type from the document's root element.
     *
     * @param xmldoc parsed key document
     * @return {@code INVALID_KEY} if the root element is not {@code <RSAKeyValue>};
     *         {@code PRIVATE_KEY} if a {@code <D>} (private exponent) element is present;
     *         {@code PUBLIC_KEY} otherwise
     */
    private static int getKeyType(Document xmldoc) {
        // getDocumentElement() skips comments/processing instructions that may precede the
        // root element; getFirstChild() would return those instead.
        Element root = xmldoc.getDocumentElement();
        if (root == null || !ROOT_NODE.equals(root.getNodeName())) {
            LOG.warning("Expecting <" + ROOT_NODE + "> node, encountered <"
                    + (root == null ? "" : root.getNodeName()) + ">");
            return INVALID_KEY;
        }
        // Classify by element name, not child-node count: indented XML contains whitespace
        // text nodes between the elements, which a raw count would include.
        return readKeyFields(xmldoc).containsKey("D") ? PRIVATE_KEY : PUBLIC_KEY;
    }

    /**
     * Checks that every element required for the given key type is present.
     *
     * @param keyType {@code PRIVATE_KEY} or {@code PUBLIC_KEY}
     * @param xmldoc  parsed key document
     * @return {@code true} if all required elements exist
     */
    private static boolean checkXMLRSAKey(int keyType, Document xmldoc) {
        Map<String, String> fields = readKeyFields(xmldoc);
        String[] wantedNodes = keyType == PRIVATE_KEY ? PRIVATE_KEY_XML_NODES : PUBLIC_KEY_XML_NODES;
        for (String wantedNode : wantedNodes) {
            if (!fields.containsKey(wantedNode)) {
                LOG.warning("Cannot find node <" + wantedNode + ">");
                return false;
            }
        }
        return true;
    }

    /**
     * Converts an XML private key to Base64-encoded PKCS#8 DER.
     *
     * @param xmldoc parsed key document containing all {@code PRIVATE_KEY_XML_NODES}
     * @return the Base64 key, or {@code null} if a value is malformed or the key is rejected
     */
    private static String convertXMLRSAPrivateKeyToPEM(Document xmldoc) {
        Map<String, String> fields = readKeyFields(xmldoc);
        try {
            RSAPrivateCrtKeySpec keySpec = new RSAPrivateCrtKeySpec(
                    unsignedField(fields, "Modulus"),
                    unsignedField(fields, "Exponent"),
                    unsignedField(fields, "D"),
                    unsignedField(fields, "P"),
                    unsignedField(fields, "Q"),
                    unsignedField(fields, "DP"),
                    unsignedField(fields, "DQ"),
                    unsignedField(fields, "InverseQ"));
            PrivateKey key = KeyFactory.getInstance("RSA").generatePrivate(keySpec);
            return b64encode(key.getEncoded());
        } catch (GeneralSecurityException | RuntimeException e) {
            LOG.log(Level.WARNING, "Failed to build RSA private key from XML", e);
            return null;
        }
    }

    /**
     * Converts an XML public key to Base64-encoded X.509 DER.
     *
     * @param xmldoc parsed key document containing {@code <Modulus>} and {@code <Exponent>}
     * @return the Base64 key, or {@code null} if a value is malformed or the key is rejected
     */
    private static String convertXMLRSAPublicKeyToPEM(Document xmldoc) {
        Map<String, String> fields = readKeyFields(xmldoc);
        try {
            RSAPublicKeySpec keySpec = new RSAPublicKeySpec(
                    unsignedField(fields, "Modulus"),
                    unsignedField(fields, "Exponent"));
            PublicKey key = KeyFactory.getInstance("RSA").generatePublic(keySpec);
            return b64encode(key.getEncoded());
        } catch (GeneralSecurityException | RuntimeException e) {
            LOG.log(Level.WARNING, "Failed to build RSA public key from XML", e);
            return null;
        }
    }

    /**
     * Collects the element children of the root element as name to text content.
     * Whitespace text nodes and comments between elements are skipped; if an element name
     * repeats, the last occurrence wins.
     */
    private static Map<String, String> readKeyFields(Document xmldoc) {
        Map<String, String> fields = new HashMap<>();
        Element root = xmldoc.getDocumentElement();
        if (root == null) {
            return fields;
        }
        NodeList children = root.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node node = children.item(i);
            if (node.getNodeType() == Node.ELEMENT_NODE) {
                fields.put(node.getNodeName(), node.getTextContent());
            }
        }
        return fields;
    }

    /**
     * Decodes a Base64 element value as an unsigned big-endian integer.
     * .NET writes these values as unsigned magnitudes with no sign byte, so the plain
     * {@code new BigInteger(bytes)} constructor would read any value whose top bit is set
     * (such as a full-length modulus) as negative.
     *
     * @return the value, or {@code null} if the element is missing or not valid Base64
     */
    private static BigInteger unsignedField(Map<String, String> fields, String name) {
        String text = fields.get(name);
        byte[] magnitude = text == null ? null : b64decode(text);
        return magnitude == null ? null : new BigInteger(1, magnitude);
    }

    /**
     * Parses an XML string into a DOM document.
     *
     * @param xmlStr XML text
     * @return the document, or {@code null} if {@code xmlStr} is null or not well-formed
     */
    private static Document parseXMLStr(String xmlStr) {
        if (xmlStr == null) {
            return null;
        }
        try {
            // Parse from a character stream: avoids the platform-default charset used by
            // String.getBytes() and any mismatch with the XML encoding declaration.
            return newDocumentBuilder().parse(new InputSource(new StringReader(xmlStr)));
        } catch (Exception e) {
            // Boundary: any failure means "unreadable", reported to callers as null.
            LOG.log(Level.WARNING, "Failed to parse XML RSA key string", e);
            return null;
        }
    }

    /**
     * Parses an XML file into a DOM document.
     *
     * @param filename path of the XML file
     * @return the document, or {@code null} if the file cannot be read or parsed
     */
    private static Document parseXMLFile(String filename) {
        if (filename == null) {
            return null;
        }
        try {
            return newDocumentBuilder().parse(new File(filename));
        } catch (Exception e) {
            // Boundary: any failure means "unreadable", reported to callers as null.
            LOG.log(Level.WARNING, "Failed to parse XML RSA key file: " + filename, e);
            return null;
        }
    }

    /**
     * Creates a DocumentBuilder with DOCTYPE declarations disallowed, which blocks XXE and
     * entity-expansion attacks. An {@code <RSAKeyValue>} document never needs a DTD.
     */
    private static DocumentBuilder newDocumentBuilder() throws ParserConfigurationException {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
        factory.setXIncludeAware(false);
        factory.setExpandEntityReferences(false);
        return factory.newDocumentBuilder();
    }

    /** Encodes bytes as standard Base64 with no line breaks. */
    private static String b64encode(byte[] data) {
        return Base64.getEncoder().encodeToString(data);
    }

    /**
     * Decodes Base64 text, ignoring all whitespace (XML values may be wrapped or indented).
     *
     * @return the decoded bytes, or {@code null} if the text is not valid Base64
     */
    private static byte[] b64decode(String data) {
        try {
            return Base64.getDecoder().decode(data.replaceAll("\\s+", ""));
        } catch (IllegalArgumentException e) {
            LOG.log(Level.WARNING, "Invalid Base64 value in XML RSA key", e);
            return null;
        }
    }

}
